{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Copy of LibreASR-LM.ipynb",
      "provenance": [],
      "collapsed_sections": [
        "dwUtLo0BCWh-"
      ],
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "cETgM72O9Rtb"
      },
      "source": [
        "!nvidia-smi"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "QQdU-qIxAfCb"
      },
      "source": [
        "# import time\n",
        "# time.sleep(10000)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UpubicuyJmCZ"
      },
      "source": [
        "# Language Model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iB2DmDXqKFhk"
      },
      "source": [
        "First, upload\n",
        "\n",
        "* `corpus.txt`\n",
        "\n",
        "* `tokenizer.yttm-model`\n",
        "\n",
        "of language you want to train"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8mLHMsU1N3X4"
      },
      "source": [
        "## Config"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zkKJ2rlIN69I"
      },
      "source": [
        "pcent = 1.0\n",
        "bs = 768\n",
        "sl = 64\n",
        "sl_shift = sl\n",
        "corpus = \"corpus.txt\"\n",
        "path_t = \"corpus-train.txt\"\n",
        "path_v = \"corpus-valid.txt\"\n",
        "lang = \"de\"\n",
        "epochs = 8"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MiAfz8lwKYU5"
      },
      "source": [
        "## Requirements"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "m5X0ZK4zJmY8"
      },
      "source": [
        "!pip3 install youtokentome\n",
        "!pip3 install -U fastai==2.1.4"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DQ1r_ruV5Eqg"
      },
      "source": [
        "## Imports"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "WkoZo_nCKTOs"
      },
      "source": [
        "import random\n",
        "\n",
        "import numpy as np\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import youtokentome as yttm\n",
        "\n",
        "from fastai.text.all import *"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "XK_p8QllTSCg"
      },
      "source": [
        "import fastai\n",
        "torch.__version__, fastai.__version__"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "31FmfZNfKfho"
      },
      "source": [
        "# load tokenizer\n",
        "tokenizer = yttm.BPE(model=\"tokenizer.yttm-model\")\n",
        "tokenizer.decode(tokenizer.encode(\"hello i'm joe äüö\"))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4HNK0b8UK0jg"
      },
      "source": [
        "# create a model\n",
        "class LM(nn.Module):\n",
        "  def __init__(self, vocab_sz, embed_sz, hidden_sz, num_layers, bs, device=\"cuda:0\", p=0.2):\n",
        "    super().__init__()\n",
        "    self.embed = nn.Embedding(vocab_sz, embed_sz, padding_idx=0)\n",
        "    self.rnn = nn.LSTM(embed_sz, hidden_sz, batch_first=True, num_layers=num_layers)\n",
        "    self.drop = nn.Dropout(p)\n",
        "    self.linear = nn.Linear(hidden_sz, vocab_sz)\n",
        "    self.h = [torch.zeros(num_layers, bs, hidden_sz, device=device) for _ in range(2)]\n",
        "    if embed_sz == hidden_sz:\n",
        "      # tie\n",
        "      self.linear.weight = self.embed.weight\n",
        "\n",
        "  def forward(self, x):\n",
        "    x = self.embed(x)\n",
        "    raw, h = self.rnn(x, self.h)\n",
        "    out = self.drop(raw)\n",
        "    self.h = [h_.detach() for h_ in h]\n",
        "    return F.log_softmax(self.linear(out), dim=-1), raw, out\n",
        "\n",
        "  def reset(self):\n",
        "    for h in self.h: h.zero_()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7Qs5rXFA3j1u"
      },
      "source": [
        "## Data Pipeline"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "QiQF_JqrOLyi"
      },
      "source": [
        "def split_file(f,out1,out2,percentage=0.75,isShuffle=True,seed=123):\n",
        "    \"\"\"Splits a file in 2 given the `percentage` to go in the large file.\"\"\"\n",
        "    random.seed(seed)\n",
        "    with open(f, 'r',encoding=\"utf-8\") as fin, open(out1, 'w') as foutBig, open(out2, 'w') as foutSmall:\n",
        "\n",
        "        nLines = sum(1 for line in fin) # if didn't count you could only approximate the percentage\n",
        "        fin.seek(0)\n",
        "        nTrain = int(nLines*percentage) \n",
        "        nValid = nLines - nTrain\n",
        "\n",
        "        i = 0\n",
        "        for line in fin:\n",
        "            r = random.random() if isShuffle else 0 # so that always evaluated to true when not isShuffle\n",
        "            if (i < nTrain and r < percentage) or (nLines - i > nValid):\n",
        "                foutBig.write(line)\n",
        "                i += 1\n",
        "            else:\n",
        "                foutSmall.write(line)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "dJabUEWj4XRD"
      },
      "source": [
        "split_file(corpus, path_t, path_v, percentage=0.8)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "lndzEMyLOXG5"
      },
      "source": [
        "!ls -lah"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ucJYpvU7Oq1Q"
      },
      "source": [
        "def read(f, pcent=0.1):\n",
        "  with open(f, 'r') as f:\n",
        "    lines = f.readlines()\n",
        "    return lines[:int(len(lines) * pcent)]"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zKRp9TySOf38"
      },
      "source": [
        "train_txt = read(path_t, pcent=pcent); train_txt[0][:80]"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "XtIS1lJSOpXO"
      },
      "source": [
        "valid_txt = read(path_v, pcent=pcent); valid_txt[0][-80:]"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9LNdWBRmVFD9"
      },
      "source": [
        "def numericalize_label(lines):\n",
        "  nums = tokenizer.encode(\"\".join(lines).replace(\"\\n\", \"\"))\n",
        "  return L((tensor(nums[i:i+sl]), tensor(nums[i+1:i+sl+1]))\n",
        "         for i in range(0,len(nums)-sl-1,sl))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7HHPeSs_W6ZX"
      },
      "source": [
        "train_nums = numericalize_label(train_txt)\n",
        "train_nums[:5]"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8jBmgTHtVtBr"
      },
      "source": [
        "valid_nums = numericalize_label(valid_txt)\n",
        "valid_nums[:5]"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "PrjH9aUkUoqF"
      },
      "source": [
        "dls = DataLoaders.from_dsets(train_nums, valid_nums, bs=bs, drop_last=True, shuffle=True)\n",
        "dls"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "AfV2h12EXuzr"
      },
      "source": [
        "# loss\n",
        "def loss_func(inp, targ):\n",
        "    return F.cross_entropy(inp.view(-1, 2048), targ.view(-1))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "iVq-J5iQX5SV"
      },
      "source": [
        "# learner\n",
        "model = LM(2048, 768, 768, 4, bs, p=0.3).cuda()\n",
        "print(f\"{sum(p.numel() for p in model.parameters()) // 1_000_000}M params\")\n",
        "cbs = [CudaCallback, ModelResetter, RNNRegularizer(alpha=2., beta=1.)]\n",
        "metrics = [accuracy, perplexity]\n",
        "learn = Learner(dls, model, loss_func=loss_func, metrics=metrics, cbs=cbs)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "p4htJSddpgll"
      },
      "source": [
        "# best:\n",
        "#  - de: valid=3.72 | perplex=41.30 | 8eps@1e-2@sp,reg,p=0.3,wd=0.1,small,Adam\n",
        "#  - en: valid=3.56 | perplex=35.00 | 8eps@1e-2@sp,reg,p=0.3,wd=0.1,small,Adam\n",
        "learn.fit_one_cycle(epochs, 1e-2, wd=0.1)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NpMPd_Go7cXT"
      },
      "source": [
        "## Save trained LM"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "KsbMBrK43FaD"
      },
      "source": [
        "model.cpu()\n",
        "model.eval()\n",
        "torch.save(model.state_dict(), f\"lm.pth\")\n",
        "!ls -lah lm.pth"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "e66W7D0D4Cf6"
      },
      "source": [
        "import time\n",
        "time.sleep(99999)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dwUtLo0BCWh-"
      },
      "source": [
        "## Copy stuff to drive"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YzX9iRYui3Rm"
      },
      "source": [
        "# To Do\n",
        "\n",
        "* generation function (with len)\n",
        "\n",
        "* better LM\n",
        "\n",
        "    * TCN\n",
        "\n",
        "    * BatchNorm\n",
        "\n",
        "    * ranger optim\n",
        "\n",
        "    * more data\n",
        "\n",
        "    * larger sl\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "tB2yddoujD3a"
      },
      "source": [
        ""
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}